  
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import copy

from torch.autograd import Variable

import matplotlib.pyplot as plt
import numpy as np

# Default compute device: CUDA when a GPU is visible to torch, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Net1conv1fcXL(nn.Module):
    """One conv layer + one fully-connected layer classifier (1-channel input).

    Architecture: Conv2d(1 -> 32, 5x5, no bias) -> tanh -> 2x2 max-pool ->
    flatten -> Linear(4608 -> nout, no bias) -> softmax over classes.

    4608 = 32 * 12 * 12, so the pooled feature map must be 12x12, which a
    28x28 input produces: (28 - 5 + 1) / 2 = 12.
    """

    def __init__(self, ch_input, nout):
        super().__init__()
        # NOTE: ch_input is accepted for signature parity with the other nets
        # but is not used; the first conv is hard-wired to 1 input channel.
        self.conv1 = nn.Conv2d(1, 32, 5, bias=False)
        self.fc1 = nn.Linear(4608, nout, bias=False)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x, do_masks=None):
        """Return per-class probabilities of shape [B, nout].

        do_masks is accepted for interface compatibility and is ignored.
        """
        x = self.pool(torch.tanh(self.conv1(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        # Explicit dim=1 (the class axis of the 2-D logits): omitting dim uses
        # the deprecated implicit-dim path, which warns on every call.
        x = F.softmax(self.fc1(x), dim=1)
        return x
    
class Net1conv1fcXL_cif(nn.Module):
    """One conv layer + one fully-connected layer classifier (3-channel input).

    Architecture: Conv2d(3 -> 32, 5x5, no bias) -> tanh -> 2x2 max-pool ->
    flatten -> Linear(6272 -> nout, no bias) -> softmax over classes.

    6272 = 32 * 14 * 14, so the pooled feature map must be 14x14, which a
    32x32 input produces: (32 - 5 + 1) / 2 = 14.
    """

    def __init__(self, ch_input, nout):
        super().__init__()
        # NOTE: ch_input is accepted for signature parity with the other nets
        # but is not used; the first conv is hard-wired to 3 input channels.
        self.conv1 = nn.Conv2d(3, 32, 5, bias=False)
        self.fc1 = nn.Linear(6272, nout, bias=False)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x, do_masks=None):
        """Return per-class probabilities of shape [B, nout].

        do_masks is accepted for interface compatibility and is ignored.
        """
        x = self.pool(torch.tanh(self.conv1(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        # Explicit dim=1 (the class axis of the 2-D logits): omitting dim uses
        # the deprecated implicit-dim path, which warns on every call.
        x = F.softmax(self.fc1(x), dim=1)
        return x

class Net2conv1fcXL_cif(nn.Module):
    """Two bias-free conv layers followed by one bias-free linear classifier.

    Shape trace for a [B, ch_input, 32, 32] input:
        conv1 (5x5)          -> [B, 32, 28, 28] -> pool -> [B, 32, 14, 14]
        conv2 (3x3, pad 1)   -> [B, 64, 14, 14] -> pool -> [B, 64, 7, 7]
        flatten              -> [B, 3136]
        fc1 + softmax(dim=1) -> [B, nout]
    """

    def __init__(self, ch_input, nout):
        super().__init__()
        # Submodules are registered in the order conv1, conv2, pool, fc1 so
        # parameters() and state_dict() ordering stays the same.
        self.conv1 = nn.Conv2d(in_channels=ch_input, out_channels=32,
                               kernel_size=5, bias=False)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=3, padding=1, bias=False)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        flat_features = 64 * 7 * 7
        self.fc1 = nn.Linear(flat_features, nout, bias=False)

    def forward(self, x, do_masks=None):
        """Map a batch of images to per-class probabilities; do_masks is ignored."""
        # Each stage: convolution -> tanh -> 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.tanh(conv(x)))
        logits = self.fc1(x.flatten(start_dim=1))
        return F.softmax(logits, dim=1)
    
    


def compute_delta_w_conv(inp, out_diff, w_shape, stride=1, sqrt=False, plot=False, plot2d=False, device='cpu'):
    """Accumulate a convolution weight update from per-position products.

    For every spatial output position (r, c), the receptive-field patch of
    ``inp`` starting at (stride * r, stride * c) is combined with
    ``out_diff[:, :, r, c]`` via ``ev`` (which averages the outer product over
    the batch). The per-position terms are summed and the sum is scaled by
    ``1 / n`` (or ``1 / sqrt(n)`` when ``sqrt`` is True), where ``n`` is the
    number of output positions.

    Args:
        inp: input tensor indexed as [batch, ch_in, row, col].
        out_diff: tensor indexed as [batch, ch_out, row, col]; its last two
            dimensions define the grid of output positions that is visited.
        w_shape: weight shape (ch_out, ch_in, ks, ...); ``w_shape[2]`` is used
            as the (square) kernel size.
        stride: step in ``inp`` between neighbouring receptive fields.
        sqrt: scale by 1/sqrt(n) instead of 1/n.
        plot: draw the [0, 0] slice of every per-position term and of the
            running sum into two matplotlib figures.
        plot2d: draw the [1, 0] slices into two further figures (independent
            of ``plot``; needs ``w_shape[0] >= 2``).
        device: device on which the result is allocated.

    Returns:
        Tensor of shape ``w_shape`` on ``device``.

    Raises:
        ValueError: if ``out_diff`` has no spatial output positions.
    """
    delta_w = torch.zeros(w_shape, device=device)
    ch_out = w_shape[0]
    ch_in = w_shape[1]
    ks = w_shape[2]
    bs = out_diff.shape[0]
    # Rows and columns are read separately so non-square output maps are
    # fully visited (previously only shape[-1] was used for both).
    rows_out, cols_out = out_diff.shape[-2], out_diff.shape[-1]
    n_pos = rows_out * cols_out
    if n_pos == 0:
        # Without this guard the final scaling divides by zero (and with
        # sqrt=True silently yields inf/NaN weights).
        raise ValueError("out_diff has no spatial output positions to accumulate over")

    # squeeze=False keeps the Axes as a 2-D array even when n_pos == 1, so the
    # [0, cnt] indexing below always works.
    if plot:
        _fig, axs = plt.subplots(1, n_pos, figsize=(12, 3), sharey=False, squeeze=False)
        _fig2, axs2 = plt.subplots(1, n_pos, figsize=(12, 3), sharey=False, squeeze=False)
    if plot2d:
        _figb, axsb = plt.subplots(1, n_pos, figsize=(12, 3), sharey=False, squeeze=False)
        _fig2b, axs2b = plt.subplots(1, n_pos, figsize=(12, 3), sharey=False, squeeze=False)

    cnt = 0
    for r in range(rows_out):
        for c in range(cols_out):
            r0 = stride * r
            c0 = stride * c
            this_out_diff = out_diff[:, :, r, c]
            this_inp = inp[:, :, r0:r0 + ks, c0:c0 + ks]

            partial = ev(this_out_diff, this_inp, bs, ch_in, ch_out, ks, device=device).reshape_as(delta_w)
            delta_w += partial

            if plot:
                axs[0, cnt].imshow(partial.detach().cpu().numpy()[0, 0])
                axs2[0, cnt].imshow(delta_w.detach().cpu().numpy()[0, 0])
            # plot2d draws into its own figures regardless of `plot`; the
            # figures are created whenever plot2d is set, so they must be filled.
            if plot2d:
                axsb[0, cnt].imshow(partial.detach().cpu().numpy()[1, 0])
                axs2b[0, cnt].imshow(delta_w.detach().cpu().numpy()[1, 0])
            cnt += 1

    delta_w *= 1. / np.sqrt(cnt) if sqrt else 1. / cnt
    return delta_w

def ev(this_out_diff, this_inp, bs, chin, chout, ks, device='cpu'):
    """Batch-averaged outer product of output diffs and an input patch.

    ``this_out_diff`` is reshaped to (bs, chout, 1, 1, 1) and ``this_inp`` to
    (bs, 1, chin, ks, ks). Their broadcast product is averaged over the batch
    axis, giving a (chout, chin, ks, ks) tensor that is moved to ``device``.
    """
    out_col = this_out_diff.reshape(bs, chout, 1, 1, 1)
    inp_row = this_inp.reshape(bs, 1, chin, ks, ks)
    batch_products = out_col * inp_row
    return batch_products.mean(dim=0).to(device)
